class Solution:
    def minMalwareSpread(self, graph: List[List[int]], initial: List[int]) -> int:
        """Return the initially infected node whose full removal saves the most nodes.

        Removing a node deletes it together with all of its edges. A connected
        component of clean (not initially infected) nodes is saved by removing
        node ``v`` only when ``v`` is the sole infected node adjacent to that
        component. Every candidate also stops counting itself as infected, but
        that constant +1 does not affect the comparison.

        Args:
            graph: ``n x n`` adjacency matrix; ``graph[i][j] == 1`` means an edge.
            initial: Initially infected node indices (duplicates are ignored).

        Returns:
            The candidate saving the most clean nodes; ties go to the smallest
            index. Returns 0 when ``initial`` is empty (mirrors the original
            sentinel behaviour).
        """
        n = len(graph)
        parent = list(range(n))
        size = [1] * n

        def find(x: int) -> int:
            # Iterative find with path halving: no recursion, so long parent
            # chains cannot trigger RecursionError on large graphs.
            while parent[x] != x:
                parent[x] = parent[parent[x]]
                x = parent[x]
            return x

        def union(a: int, b: int) -> None:
            # Union by size keeps trees shallow; size[] is only meaningful at roots.
            ra, rb = find(a), find(b)
            if ra == rb:
                return
            if size[ra] > size[rb]:
                ra, rb = rb, ra
            parent[ra] = rb
            size[rb] += size[ra]

        candidates = sorted(set(initial))
        infected = set(candidates)
        clean = [node not in infected for node in range(n)]

        # Phase 1: connect clean nodes into components (infected nodes excluded).
        for i in range(n):
            if not clean[i]:
                continue
            row = graph[i]
            for j in range(i + 1, n):
                if clean[j] and row[j] == 1:
                    union(i, j)

        # Phase 2: for each candidate, collect the clean components it touches,
        # and count how many distinct candidates touch each component.
        touch_count = collections.Counter()
        touches = {}
        for node in candidates:
            row = graph[node]
            roots = {find(j) for j in range(n) if clean[j] and row[j] == 1}
            touch_count.update(roots)
            touches[node] = roots

        # Phase 3: a component is saved only if exactly one candidate reaches it.
        # Candidates are visited in ascending order, so strict '>' keeps the
        # smallest index on ties.
        best_node, best_saved = 0, -1
        for node in candidates:
            saved = sum(size[root] for root in touches[node] if touch_count[root] == 1)
            if saved > best_saved:
                best_node, best_saved = node, saved
        return best_node
